#!/bin/bash

# ------------------------------ Configuration ------------------------------
# Input datasets, resolved relative to the directory holding this script:
# one brand-separated split and one time-separated split.
BRAND_DATASET="$(dirname "$0")/dataset/test_tweet_comp_be.jsonl"
TIME_DATASET="$(dirname "$0")/dataset/test_tweet_ran_be.jsonl"

# Python evaluation script, looked up next to this file.
SCRIPT_PATH="$(dirname "$0")/gpt_nopersonaten.py"

# Output tree. The parent path is relative, so it is created under the
# directory the script is invoked from.
OUTPUT_PARENT="twitter_gpt_results"
BRAND_OUT_DIR="${OUTPUT_PARENT}/brand_separated"
TIME_OUT_DIR="${OUTPUT_PARENT}/time_separated"

# GPU pool: 8 device ids with 1 concurrent job per GPU, i.e. 8 parallel jobs.
GPUS=(0 1 2 3 4 5 6 7)
MAX_JOBS_PER_GPU=1
TOTAL_JOBS=$(( MAX_JOBS_PER_GPU * ${#GPUS[@]} ))

# Ensure every output location (results and logs) exists up front.
mkdir -p \
  "$BRAND_OUT_DIR" \
  "$TIME_OUT_DIR" \
  "${OUTPUT_PARENT}/logs"

# ---------------------------------------------------------------------------
# Split a line-oriented dataset into contiguous index ranges, run one
# evaluation process per range (GPUs assigned round-robin), wait for all of
# them, then merge the per-range JSON outputs into a single file.
#
# Globals (read): GPUS, TOTAL_JOBS, OUTPUT_PARENT, SCRIPT_PATH
# Arguments:
#   $1 : dataset path (one record per line)
#   $2 : output dir passed to the evaluation script; its basename names the
#        merged file and prefixes this dataset's log files
# Outputs:
#   logs   -> $OUTPUT_PARENT/logs/<basename>_job_<start>_<end>_gpu<id>.log
#   merged -> $OUTPUT_PARENT/<basename>_merged.json
# Returns:
#   0 on success; non-zero if the dataset is unreadable, the GPU/job
#   configuration is empty, any slice job exits non-zero, or the merge fails.
launch_jobs() {
  local data_path="$1"
  local out_dir="$2"
  local tag="${out_dir##*/}"
  local log_dir="${OUTPUT_PARENT}/logs"

  if [[ ! -r "$data_path" ]]; then
    echo "ERROR: dataset not found or unreadable: $data_path" >&2
    return 1
  fi
  # Guard the division and modulo below against an empty configuration.
  if (( ${#GPUS[@]} < 1 || TOTAL_JOBS < 1 )); then
    echo "ERROR: GPUS and TOTAL_JOBS must both be non-empty/positive" >&2
    return 1
  fi
  mkdir -p "$out_dir" "$log_dir" || return 1

  # Count records with awk rather than `wc -l`: wc counts newline characters,
  # so a file whose last line has no trailing newline would lose one record.
  local total_lines
  total_lines=$(awk 'END { print NR }' < "$data_path") || return 1
  echo "Dataset: $data_path  (lines=$total_lines)"

  # Ceiling division so that at most TOTAL_JOBS chunks cover every line.
  local data_per_job=$(( (total_lines + TOTAL_JOBS - 1) / TOTAL_JOBS ))
  echo "Splitting into chunks of $data_per_job lines across $TOTAL_JOBS jobs"

  local -a pids=()
  local failed=0 job_counter=0 start_idx=0
  local end_idx gpu_id pid
  while (( start_idx < total_lines )); do
    # Inclusive end index, clamped to the last record.
    end_idx=$(( start_idx + data_per_job - 1 ))
    if (( end_idx >= total_lines )); then
      end_idx=$(( total_lines - 1 ))
    fi

    # Select GPU in round-robin fashion.
    gpu_id=${GPUS[job_counter % ${#GPUS[@]}]}

    echo "  Launching $start_idx-$end_idx on GPU $gpu_id"
    # Only the chosen GPU is made visible to the process; presumably that is
    # why --gpu_id is always 0 (confirm against the evaluation script).
    # The log name carries the dataset tag so that two datasets with the
    # same index ranges do not overwrite each other's logs.
    CUDA_VISIBLE_DEVICES=$gpu_id nohup python -u "$SCRIPT_PATH" \
      --tweet_eval \
      --dataset_paths "$data_path" \
      --start "$start_idx" \
      --end "$end_idx" \
      --output_dir "$out_dir" \
      --gpu_id 0 \
      > "${log_dir}/${tag}_job_${start_idx}_${end_idx}_gpu${gpu_id}.log" 2>&1 &
    pids+=("$!")

    job_counter=$(( job_counter + 1 ))
    start_idx=$(( end_idx + 1 ))

    # Throttle: once a full batch is running, wait for it before continuing.
    # Waiting per PID (instead of a bare `wait`) exposes each exit status.
    if (( job_counter % TOTAL_JOBS == 0 )); then
      echo "Waiting for current batch of $TOTAL_JOBS jobs to finish..."
      for pid in "${pids[@]}"; do
        wait "$pid" || failed=$(( failed + 1 ))
      done
      pids=()
      echo "Resuming launches..."
    fi
  done

  # Wait for any remaining background jobs.
  for pid in "${pids[@]}"; do
    wait "$pid" || failed=$(( failed + 1 ))
  done

  if (( failed > 0 )); then
    echo "WARNING: $failed slice job(s) for $data_path exited non-zero;" \
      "see logs in $log_dir" >&2
  fi

  # ---------------- Merge partial JSON outputs ----------------
  # Still attempted after failures so that completed slices are not lost.
  # Paths are passed as arguments and the heredoc delimiter is quoted, so
  # quotes or backslashes in a path cannot alter the Python source.
  # NOTE(review): every tweet_results_*.json in the output dir is merged,
  # which would include files left over from an earlier run, if any.
  echo "Merging partial outputs in $out_dir ..."
  local merge_status=0
  python - "$out_dir" "$OUTPUT_PARENT" <<'PY' || merge_status=$?
import glob, json, os, sys
out_dir, parent = sys.argv[1], sys.argv[2]
pattern = os.path.join(out_dir, 'tweet_results_*.json')
merged = []
for fp in sorted(glob.glob(pattern)):
    with open(fp, 'r') as f:
        merged.extend(json.load(f))
merged_name = os.path.basename(out_dir) + '_merged.json'
merge_path = os.path.join(parent, merged_name)
with open(merge_path, 'w') as f:
    json.dump(merged, f, indent=2)
print(f'Merged file written to {merge_path} (records={len(merged)})')
PY

  if (( failed > 0 || merge_status != 0 )); then
    return 1
  fi
  return 0
}

# ------------------------------ Main ----------------------------------------
# Process both datasets even if the first one fails, then report an overall
# status so callers (and CI) do not see a success message after a failure.
overall_status=0
launch_jobs "$BRAND_DATASET" "$BRAND_OUT_DIR" || overall_status=1
launch_jobs "$TIME_DATASET"  "$TIME_OUT_DIR" || overall_status=1

if (( overall_status != 0 )); then
  echo "One or more datasets failed; check logs under $OUTPUT_PARENT/logs" >&2
  exit 1
fi

echo "All tasks completed. Final merged JSONs are in $OUTPUT_PARENT"